import sqlalchemy.exc as exc
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import Float
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base

from plyny.table import DataTable, _IterableData

# Module-wide declarative base.  Only its ``metadata`` is used in this file:
# every SqlData registers its Table object in this single shared MetaData
# (Table(..., Base.metadata, ...)), whatever engine that SqlData uses.
Base = declarative_base()


class DuplicateError(ValueError):
    """Raised by SqlData.append when the database rejects a row with an
    IntegrityError (typically a duplicate value in a unique column)."""


class SqlData(_IterableData):
    """Row storage backed by a single SQL table.

    Rows are inserted with append()/extend() and read back by iterating,
    which yields each row as a tuple.  The table is defined lazily: either
    from the first appended row (see _get_meta), or by reflecting an existing
    table of the same name on the first read (get, delete, len, iter).
    """

    def __init__(self, engine, table_name, *column_names):
        """
        :param engine: SQLAlchemy engine the table lives in.
        :param table_name: SQL table name; a random UUID string when None.
        :param column_names: names of the table's columns (at least one).
        :raises ValueError: if no column names are given.
        """
        super(SqlData, self).__init__()
        if not column_names:
            raise ValueError('need column_names')

        self.engine = engine
        if table_name is None:
            import uuid
            table_name = str(uuid.uuid4())

        self._name = table_name

        self._table = None       # sqlalchemy Table, bound by _ensure_loaded
        self._insert = None      # INSERT construct for self._table
        self.unique = set()      # column names created with UNIQUE
        self.column_names = column_names
        self._connection = None  # opened once by _ensure_loaded, then reused

    def set_keys(self, *keys):
        """Mark columns as unique.

        Only takes effect if called before the table is defined (i.e. before
        the first append).
        """
        self.unique = set(keys)

    def _get_meta(self, values):
        """Define and create the SQL table from the first row, once.

        Column types come from ``self._type_guesses``, which the base class's
        _get_meta presumably fills in from *values*: int -> Integer,
        float -> Float, bool -> Boolean, anything else -> text.
        """
        if self._table is not None:
            return

        super(SqlData, self)._get_meta(values)
        columns = []
        for name, type_guess in zip(self.column_names, self._type_guesses):
            if type_guess is int:
                sql_type = Integer
            elif type_guess is float:
                sql_type = Float
            elif type_guess is bool:
                sql_type = Boolean
            else:
                # MEDIUMTEXT is MySQL-only and cannot be rendered by other
                # dialects (e.g. SQLite); keep it on MySQL, use Text elsewhere.
                sql_type = Text().with_variant(MEDIUMTEXT(), 'mysql')

            if name in self.unique:
                columns.append(Column(name, sql_type, unique=True))
            else:
                columns.append(Column(name, sql_type))

        self._ensure_loaded(columns)

    def extend(self, other):
        """Append every row of *other*; returns self."""
        # Plain loop on purpose: map() is lazy on Python 3 and would insert
        # nothing.
        for row in other:
            self.append(row)
        return self

    def _ensure_loaded(self, columns=None):
        """Bind self._table / self._insert and open the connection.

        With *columns*, define the table from them and create it in the
        database if it does not exist yet.  Without, reflect an existing
        table (sqlalchemy.exc.NoSuchTableError if there is none).
        """
        if columns:
            table = Table(self._name, Base.metadata, *columns)
            # Create only this table: Base.metadata is shared by every
            # SqlData, so create_all() would also create unrelated tables on
            # this engine.
            table.create(self.engine, checkfirst=True)
        else:
            table = Table(self._name, Base.metadata, autoload=True,
                          autoload_with=self.engine)

        self._table = table
        self._insert = table.insert()
        # Reuse one connection instead of leaking a new one on every call.
        if self._connection is None:
            self._connection = self.engine.connect()

    def _load_existing(self):
        """Reflect the table if not bound yet; return False if it does not exist."""
        if self._connection is None:
            try:
                self._ensure_loaded()
            except exc.NoSuchTableError:
                return False
        return True

    def get(self, column, key):
        """Return the first row whose *column* equals *key*, or None.

        Also returns None when the table does not exist yet.
        """
        if not self._load_existing():
            return None
        # Bound comparison instead of string-built SQL: string keys are
        # quoted correctly and the value cannot inject SQL.
        query = self._table.select().where(self._table.c[column] == key)
        return self._connection.execute(query).fetchone()

    def is_empty(self):
        return False  # XXX to trick DataTable into iterating over us no matter what

    def __len__(self):
        """Row count (0 when the table does not exist yet)."""
        if not self._load_existing():
            return 0
        return self._connection.execute(self._table.select().count()).fetchone()[0]

    def delete(self, key_name, key_value):
        """Delete every row whose *key_name* column equals *key_value*.

        Does nothing when the table does not exist yet.
        """
        if not self._load_existing():
            return
        # '==' is SQLite-only syntax (MySQL rejects it); a bound comparison
        # is portable and escapes the value.
        condition = self._table.c[key_name] == key_value
        self._connection.execute(self._table.delete().where(condition))

    def force_append(self, row):
        """Currently a no-op: *row* is ignored."""
        # NOTE(review): presumably meant to insert even on key conflicts;
        # confirm intended semantics before relying on it.
        pass

    def append(self, row):
        """Insert *row*, defining the table from it first if needed.

        :returns: self
        :raises DuplicateError: if the database raises IntegrityError
            (e.g. a duplicate value in a unique column).
        """
        if self._table is None:
            self._get_meta(row)

        try:
            self._connection.execute(self._insert.values(row))
        except IntegrityError as e:
            raise DuplicateError(e)
        return self

    def __iter__(self):
        """Iterate over all rows as tuples (nothing if the table does not exist)."""
        if not self._load_existing():
            return iter(())
        rows = self._connection.execute(self._table.select())
        return (tuple(row) for row in rows)


class SqlDataTable(DataTable):
    """DataTable whose rows are stored in a SQL table through SqlData.

    The factory classmethods return a callable taking the column names,
    e.g. ``SqlDataTable.sqlite3_memory('t')('a', 'b')``.
    """

    def __init__(self, engine, table_name, *column_names):
        """
        :param engine: SQLAlchemy engine holding the table.
        :param table_name: SQL table name (a random one when None).
        :param column_names: at least one column name.
        :raises ValueError: if no column names are given.
        """
        if not column_names:
            raise ValueError('must have column_names')

        storage = SqlData(engine, table_name, *column_names)
        # NOTE(review): super(DataTable, self) skips DataTable.__init__ and
        # runs the next initializer in the MRO with the SqlData store --
        # presumably intentional so DataTable does not build its own
        # storage. Confirm against plyny.table before changing.
        super(DataTable, self).__init__(storage, *column_names)

    def __iter__(self):
        """Iterate over the rows of the backing SqlData."""
        return iter(self._data)

    def force_add(self, *row):
        """Forward *row* (packed as a tuple) to SqlData.force_append."""
        self._data.force_append(row)

    @classmethod
    def sqlite3_memory(cls, table_name):
        """Factory for tables in an in-memory SQLite database.

        Every call of the returned callable creates a new engine.
        """
        def build(*column_names):
            engine = create_engine('sqlite:///:memory:', echo=False,
                                   pool_size=30)
            return cls(engine, table_name, *column_names)
        return build

    @classmethod
    def sqlite3(cls, filename, table_name):
        """Factory for tables in the SQLite database file *filename*.

        Every call of the returned callable creates a new engine.
        """
        def build(*column_names):
            engine = create_engine('sqlite:///%s' % filename, echo=False,
                                   pool_size=30)
            return cls(engine, table_name, *column_names)
        return build

    def set_keys(self, *keys):
        """Mark *keys* as unique columns of the backing SqlData."""
        self._data.unique = set(keys)

    @classmethod
    def mysql(cls, user=None, passwd=None, host=None, name=None, table_name=None):
        """Factory for tables in a MySQL database.

        The arguments are interpolated verbatim into the connection URL, so
        an unset (None) value appears literally as 'None'.
        """
        url = 'mysql://%s:%s@%s/%s' % (user, passwd, host, name)

        def build(*column_names):
            return cls(create_engine(url, echo=False), table_name,
                       *column_names)
        return build

    def delete(self, key_name, key_value):
        """Delete rows whose *key_name* column equals *key_value*."""
        self._data.delete(key_name, key_value)

    def get(self, column, key):
        """Return the first row whose *column* equals *key* (see SqlData.get)."""
        return self._data.get(column, key)

    
if __name__ == '__main__':
    # Smoke demo: a two-column table in an in-memory SQLite database.
    # Uses numeric columns only; column types are guessed from the first row.
    engine = create_engine('sqlite:///:memory:', echo=False)
    data = SqlData(engine, 'demo', 'a', 'b')
    data.set_keys('a')  # column 'a' is created with a UNIQUE constraint

    data.append((1, 0.5))
    data.extend([(2, 1.5), (3, 2.5)])

    try:
        data.append((1, 9.5))  # same 'a' value as the first row
    except DuplicateError:
        print('duplicate key rejected')

    print(len(data))
    for row in data:
        if row[0] > 2:
            print(row)
